import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
import os

from sklearn.model_selection import KFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

from sklearn.metrics import accuracy_score, r2_score
from sklearn import linear_model
from tqdm import tqdm

# Target device for tensors produced below: the GPU when torch can see one,
# otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"


def read_signal(signal_path):
    """Read a text file holding one numeric sample per line into a float32 array.

    Each line is stripped of surrounding whitespace (including the newline)
    before parsing. This means the final sample is read correctly even when
    the file does not end with a newline. Blank lines are skipped, so a
    trailing empty line does not break parsing.

    Args:
        signal_path: path to the text file.

    Returns:
        1-D ``np.float32`` array with one entry per non-blank line
        (empty if the file has no samples).

    Raises:
        ValueError: if a non-blank line cannot be parsed as a number.
    """
    with open(signal_path, encoding="utf-8") as f:
        values = [line.strip() for line in f]
    return np.array([v for v in values if v], dtype=np.float32)

    
class SimCLR_DataLoader(Dataset):
    """Dataset yielding pairs of signal windows for contrastive pre-training.

    Each row ``self.csv[ix, :]`` is unpacked as ``(pat_path, strip, study)``.
    The strip file ``os.path.join(pat_path, str(strip))`` is loaded with
    ``np.load``. Windows are ``l * fs`` samples long along axis 0.

    Pairing strategy, selected by ``method``:

    * ``"clocs"``: two adjacent, non-overlapping windows from the same strip.
      Each is divided by the study std and jittered with N(0, 0.01) noise.
      Each is then either reversed along dim 0 or sign-inverted, with
      probability 1/2 each.
    * any other value: one window from the strip and one from a randomly
      chosen file of the other study (see ``get_strips``). Both are divided
      by the std of the first strip's study.

    Args:
        csv: indexable as ``csv[ix, :]`` yielding ``(pat_path, strip, study)``;
            presumably a 2-D numpy object array. TODO confirm with caller.
        method: ``"clocs"`` or any other value (cross-study pairing).
        l: window length factor; a window spans ``l * fs`` samples
            (presumably seconds times Hz; verify).
        fs: see ``l``.
        std_s1: divisor applied when ``study == 1``.
        std_s2: divisor applied for every other study value.
        study_char_idx: index of the character in ``pat_path`` that is
            replaced with the other study's number in ``get_strips``.
            It defaults to 10, the position previously hard-coded.
    """

    def __init__(self, csv, method, l=10, fs=100, std_s1=0.12636206,
                 std_s2=0.14737946, study_char_idx=10):
        self.std_s1 = std_s1
        self.std_s2 = std_s2
        self.csv = csv
        self.l = l
        self.fs = fs
        self.method = method
        self.study_char_idx = study_char_idx

    def __len__(self):
        return len(self.csv)

    def __getitem__(self, ix):
        """Return the two views ``(x1, x2)`` for row ``ix`` as float32 tensors on ``device``."""
        pat_path, strip, study = self.csv[ix, :]
        study = int(study)
        std = self.std_s1 if study == 1 else self.std_s2
        tmp_strip = np.load(os.path.join(pat_path, str(strip)))

        if self.method == "clocs":
            x1, x2 = self.get_strips_clocs(tmp_strip)
            # RNG draw order: jitter for x1, jitter for x2, then one coin per view.
            x1 = x1 / std + np.random.normal(0, 0.01, size=x1.shape)
            x2 = x2 / std + np.random.normal(0, 0.01, size=x2.shape)
            x1 = self._flip_or_negate(torch.tensor(x1, dtype=torch.float32))
            x2 = self._flip_or_negate(torch.tensor(x2, dtype=torch.float32))
        else:
            x1, x2 = self.get_strips(tmp_strip, pat_path, study)
            x1 = torch.tensor(x1 / std, dtype=torch.float32)
            x2 = torch.tensor(x2 / std, dtype=torch.float32)

        # NOTE(review): moving to CUDA inside __getitem__ breaks when a
        # DataLoader uses worker processes (num_workers > 0); consider moving
        # batches to the device in the training loop instead.
        return x1.to(device), x2.to(device)

    @staticmethod
    def _random_start(n_samples, window):
        """Draw a uniform start so that ``[start, start + window)`` fits in ``n_samples``.

        Every start in ``0 .. n_samples - window`` (inclusive) can be drawn,
        so a strip exactly ``window`` samples long yields start 0.

        Raises:
            ValueError: if ``n_samples < window``.
        """
        slack = n_samples - window
        if slack < 0:
            raise ValueError(
                f"strip has {n_samples} samples, need at least {window}")
        # randint's upper bound is exclusive, hence the +1.
        return np.random.randint(0, slack + 1)

    @staticmethod
    def _flip_or_negate(x):
        """Return ``x`` reversed along dim 0 or sign-inverted, each with probability 1/2."""
        if np.random.uniform() > 0.5:
            return torch.flip(x, dims=[0])
        return -x

    def get_strips(self, tmp_strip, pat_path, tmp_study):
        """Return one window from ``tmp_strip`` and one from a strip of the other study.

        The companion directory is ``pat_path`` with the character at
        ``self.study_char_idx`` replaced by ``"2"`` when ``tmp_study == 1``,
        otherwise by ``"1"``. One file is picked uniformly from that directory
        and loaded with ``np.load``. The directory listing is sorted, so under
        a fixed NumPy seed the pick does not depend on filesystem order.

        Raises:
            ValueError: if either strip is shorter than ``l * fs`` samples or
                the other-study directory is empty.
        """
        win = self.l * self.fs
        start = self._random_start(tmp_strip.shape[0], win)
        x1 = tmp_strip[start:start + win]

        other_study = "2" if tmp_study == 1 else "1"
        k = self.study_char_idx
        other_pat_path = pat_path[:k] + other_study + pat_path[k + 1:]
        candidates = sorted(os.listdir(other_pat_path))
        if not candidates:
            raise ValueError(f"no strips found in {other_pat_path!r}")
        strip_2 = np.random.choice(candidates)
        tmp_strip_2 = np.load(os.path.join(other_pat_path, strip_2))

        start_2 = self._random_start(tmp_strip_2.shape[0], win)
        x2 = tmp_strip_2[start_2:start_2 + win]
        return x1, x2

    def get_strips_clocs(self, tmp_strip):
        """Return two adjacent, non-overlapping ``l * fs``-sample windows of ``tmp_strip``.

        Raises:
            ValueError: if the strip is shorter than ``2 * l * fs`` samples.
        """
        win = self.l * self.fs
        start = self._random_start(tmp_strip.shape[0], 2 * win)
        mid = start + win
        return tmp_strip[start:mid], tmp_strip[mid:mid + win]
    

def train_batch(batch, model, optimizer):
    """Perform a single optimisation step on ``batch`` and return the loss as a float.

    The model is switched to training mode and gradients are reset. ``batch``
    is unpacked into two views, which are passed to ``model``. The model's
    output is treated as the loss tensor (backpropagated, then
    ``optimizer.step()``). The loss value is the one computed before the
    parameter update.
    """
    model.train()
    optimizer.zero_grad()
    first_view, second_view = batch

    step_loss = model(first_view, second_view)
    step_loss.backward()
    optimizer.step()

    return step_loss.item()

